import math
import copy
import gym
import random
import numpy as np
import statistics
import pickle
import argparse
from pathlib import Path

# Ensure the env ID "ImprovedWalker2d-v0" is registered.
# If your file is named differently, change this import accordingly.
import improved_walker2d  # noqa: F401

from SnapshotENV import SnapshotEnv

# -------------------------------
# Progressive-widening MCTS node (same core logic as the standalone PW script)
# -------------------------------
# Discount factor (gamma) applied per step to rewards, e.g. in NodePW rollouts
# and in run_one's test-time discounted return.
discount = 0.99

class NodePW:
    """MCTS search-tree node with progressive widening over a continuous action box.

    Each node holds the environment snapshot reached by applying ``action``
    (a tuple of ``dim`` floats drawn uniformly from ``[min_action, max_action]``)
    from its parent. ``immediate_reward`` is the reward of that transition.
    ``visit_count`` / ``value_sum`` accumulate sampled discounted returns from
    this node's state onward; they exclude ``immediate_reward``, which the
    parent adds when backing up. The state-action value of a child is
    therefore ``child.immediate_reward + discount * child.get_mean_value()``.

    ``env`` arguments must provide ``get_snapshot()``, ``load_snapshot(snap)``,
    ``get_result(snap, action)`` (returning an object with ``snapshot``,
    ``next_state``, ``reward`` and ``is_done``) and a gym-style 4-tuple
    ``step(action)``.
    """

    def __init__(self, snapshot, obs, is_done, parent, depth, min_action, max_action, dim):
        self.parent = parent
        self.snapshot = snapshot     # env snapshot of this node's state (None until known)
        self.obs = obs
        self.is_done = is_done       # True if the transition into this node ended the episode
        self.depth = depth           # distance from the current planning root

        self.children = []           # child nodes
        self.actions = []            # actions tried (parallel to children)
        self.immediate_reward = 0.0  # reward of the transition parent -> self

        self.min_action = min_action
        self.max_action = max_action
        self.dim = dim

        self.visit_count = 0
        self.value_sum = 0.0         # sum of sampled returns from this node's state onward
        self.action = None           # action leading here from the parent; None for a root

    def get_mean_value(self):
        """Return the mean sampled return from this node's state (0.0 if never visited)."""
        return 0.0 if self.visit_count == 0 else (self.value_sum / self.visit_count)

    def select_child(self):
        """Return the child maximising UCB1 over Q = r + discount * V, or None if childless."""
        best_child = None
        # -inf (not a finite sentinel) so a child is always chosen, however negative its score.
        best_score = -math.inf
        log_n = math.log(max(1, self.visit_count))
        c = 1.414  # sqrt(2)
        for child in self.children:
            # Compare state-action values: the reward for reaching the child counts too.
            q_value = child.immediate_reward + discount * child.get_mean_value()
            n = max(child.visit_count, 1)
            ucb = q_value + c * math.sqrt(log_n / n)
            if ucb > best_score:
                best_score = ucb
                best_child = child
        return best_child

    def expand(self):
        """Add one child with a uniformly random action if the widening budget allows.

        Returns the new (not yet simulated) child, or None when the number of
        children already meets the budget K, which grows with sqrt(visit_count).
        """
        # dimension-aware progressive widening budget
        if self.dim <= 3:
            K = max(5, int(0.5 * math.sqrt(self.visit_count)))
        else:
            K = max(3, int(0.3 * math.sqrt(self.visit_count)))

        if len(self.children) < K:
            action = tuple(random.uniform(self.min_action, self.max_action) for _ in range(self.dim))
            child_node = NodePW(
                snapshot=None,
                obs=None,
                is_done=False,
                parent=self,
                depth=self.depth + 1,
                min_action=self.min_action,
                max_action=self.max_action,
                dim=self.dim
            )
            child_node.action = action
            self.children.append(child_node)
            self.actions.append(action)
            return child_node
        return None

    def rollout(self, env, max_depth):
        """Estimate this node's value with a uniformly random rollout.

        Returns the discounted return of up to ``max_depth - depth`` random
        steps from this node's snapshot. Returns 0.0 without touching ``env``
        when the node is terminal or already at the depth limit.
        """
        # A terminal node has no future reward; stepping past `done` is invalid in gym.
        if self.is_done or self.depth >= max_depth:
            return 0.0
        if self.snapshot is not None:
            env.load_snapshot(self.snapshot)

        total = 0.0
        df = 1.0
        for _ in range(max_depth - self.depth):
            action = tuple(random.uniform(self.min_action, self.max_action) for _ in range(self.dim))
            _, r, done, _ = env.step(action)
            total += df * r
            df *= discount
            if done:
                break
        return total

    def selection(self, env, max_depth):
        """Run one simulation from this node and back its result up.

        Either expands a new child (then evaluates it with a random rollout)
        or descends into the UCB-best child. Returns the sampled discounted
        return from this node's state onward.
        """
        if self.is_done or self.depth >= max_depth:
            # Count the visit so UCB's exploration bonus for this leaf shrinks;
            # otherwise a terminal leaf would be re-selected ever more often.
            self.visit_count += 1
            return 0.0

        # Progressive widening: try a new child if allowed
        new_child = self.expand()
        if new_child is not None:
            if self.snapshot is None:
                # NOTE(review): assumes env currently sits at this node's state.
                self.snapshot = env.get_snapshot()
            res = env.get_result(self.snapshot, new_child.action)
            new_child.snapshot = res.snapshot
            new_child.obs = res.next_state
            new_child.is_done = res.is_done
            new_child.immediate_reward = res.reward

            # One-shot rollout to evaluate the new child (useful for high-dim actions).
            value = new_child.rollout(env, max_depth)
            # Credit the rollout to the child itself too, so its statistics
            # (used by UCB and by final action choice) reflect this simulation.
            new_child.visit_count += 1
            new_child.value_sum += value
            child_return = new_child.immediate_reward + discount * value
        else:
            best = self.select_child()
            if best is None:
                return 0.0
            child_return = best.immediate_reward + discount * best.selection(env, max_depth)

        self.visit_count += 1
        self.value_sum += child_return
        return child_return


def _reroot(node):
    """Make ``node`` a tree root: detach it and renumber the depths of its whole subtree."""
    node.parent = None
    node.depth = 0
    stack = [node]
    while stack:
        current = stack.pop()
        for child in current.children:
            # Depth drives the planning horizon (max_depth - depth), so stale
            # depths from the old root would silently shorten the search.
            child.depth = current.depth + 1
            stack.append(child)


def _choose_action(root, action_low, action_high, action_dim):
    """Return ``(child, action)`` for the root child with the best Q = r + discount * V.

    Falls back to ``(None, uniformly random action)`` when the root has no children.
    """
    if not root.children:
        action = tuple(random.uniform(action_low, action_high) for _ in range(action_dim))
        return None, action
    best = max(root.children, key=lambda c: c.immediate_reward + discount * c.get_mean_value())
    return best, best.action


def _evaluate_seed(plan_env, root_obs0, root_snap0, seed, iters, max_depth,
                   action_low, action_high, action_dim, test_horizon):
    """Plan-and-act from the shared start state for one seed; return the discounted test return."""
    random.seed(seed)
    np.random.seed(seed)

    root = NodePW(
        snapshot=root_snap0,
        obs=copy.copy(root_obs0),
        is_done=False,
        parent=None,
        depth=0,
        min_action=action_low,
        max_action=action_high,
        dim=action_dim
    )

    # Test-time execution on a clone of the start state. The snapshot is
    # produced in-process by SnapshotEnv; never pickle.loads untrusted data.
    test_env = pickle.loads(root_snap0)
    total_reward = 0.0
    df = 1.0
    try:
        for _step in range(test_horizon):
            for _ in range(iters):
                root.selection(plan_env, max_depth)

            best_child, best_action = _choose_action(root, action_low, action_high, action_dim)
            obs, r, done, _ = test_env.step(best_action)
            total_reward += df * r
            df *= discount
            if done:
                break

            # Keep the chosen subtree (other branches become garbage), or
            # start a fresh one if the action was random.
            if best_child is None:
                best_child = NodePW(
                    snapshot=None,
                    obs=None,
                    is_done=False,
                    parent=None,
                    depth=0,
                    min_action=action_low,
                    max_action=action_high,
                    dim=action_dim
                )
                best_child.action = best_action
            # Sync the new root with the real test-env state (the inverse of the
            # pickle.loads above). Otherwise a fresh root would later snapshot
            # plan_env's arbitrary post-rollout state.
            best_child.snapshot = pickle.dumps(test_env)
            best_child.obs = obs
            root = best_child
            _reroot(root)
    finally:
        test_env.close()
    return total_reward


def run_one(env_id: str, iters: int, seeds: int, max_depth: int,
            action_low: float, action_high: float, action_dim: int,
            test_horizon: int):
    """Run PW planning with ``iters`` simulations per step over several seeds.

    For each seed, the planner re-plans at every real step (re-using the
    chosen subtree) for up to ``test_horizon`` steps. It records the
    discounted return obtained in a cloned test environment.

    Returns:
        ``(mean, population std, per-seed list)`` of discounted test returns.

    Raises:
        ValueError: if ``seeds`` < 1.
    """
    if seeds < 1:
        raise ValueError(f"seeds must be >= 1, got {seeds}")

    # Build a planning env (snapshottable); every seed starts from the same state.
    plan_env = SnapshotEnv(gym.make(env_id).env)
    root_obs0 = plan_env.reset()
    root_snap0 = plan_env.get_snapshot()

    seed_returns = []
    for seed in range(seeds):
        seed_returns.append(
            _evaluate_seed(plan_env, root_obs0, root_snap0, seed, iters, max_depth,
                           action_low, action_high, action_dim, test_horizon)
        )

    mean_ret = statistics.mean(seed_returns)
    std_ret = statistics.pstdev(seed_returns) if len(seed_returns) > 1 else 0.0
    return mean_ret, std_ret, seed_returns


def main():
    """Parse CLI arguments, sweep the ITER budgets with run_one and record results.

    For each ITER budget, one summary line is printed and appended to
    ``--out_txt``. ``--out_csv`` is overwritten with a header plus one row
    per budget.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--env", default="ImprovedWalker2d-v0", type=str)
    p.add_argument("--iters", nargs="+", type=int, default=[3, 4, 7, 11, 18, 29])
    p.add_argument("--seeds", type=int, default=20)
    p.add_argument("--test_horizon", type=int, default=150)
    p.add_argument("--max_depth", type=int, default=100)
    # Action box; defaults match Walker2d's action space Box([-1, 1], shape=(6,)).
    # Exposed so that --env can point at other continuous-control tasks.
    p.add_argument("--action_low", type=float, default=-1.0)
    p.add_argument("--action_high", type=float, default=1.0)
    p.add_argument("--action_dim", type=int, default=6)
    p.add_argument("--out_txt", type=str, default="pw_walker2d_results.txt")
    p.add_argument("--out_csv", type=str, default="pw_walker2d_results.csv")
    args = p.parse_args()

    env_id = args.env
    action_low = args.action_low
    action_high = args.action_high
    action_dim = args.action_dim
    if action_dim < 1:
        p.error("--action_dim must be >= 1")
    if action_low > action_high:
        p.error("--action_low must not exceed --action_high")

    print(f"Env: {env_id}")
    print(f"Seeds: {args.seeds}")
    print(f"ITER list: {args.iters}")
    print(f"Action dim: {action_dim}, range: [{action_low}, {action_high}]")
    print(f"Max planning depth: {args.max_depth}, Test horizon: {args.test_horizon}")

    Path(args.out_txt).parent.mkdir(parents=True, exist_ok=True)
    Path(args.out_csv).parent.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8: the summary line contains "±", which the platform's
    # default encoding is not guaranteed to represent.
    with open(args.out_txt, "a", encoding="utf-8") as ftxt, \
            open(args.out_csv, "w", encoding="utf-8") as fcsv:
        fcsv.write("env,iter,seeds,mean,std\n")
        for it in args.iters:
            mean_ret, std_ret, _ = run_one(
                env_id=env_id,
                iters=it,
                seeds=args.seeds,
                max_depth=args.max_depth,
                action_low=action_low,
                action_high=action_high,
                action_dim=action_dim,
                test_horizon=args.test_horizon
            )
            # The text summary reports mean ± 2·std; the CSV stores the raw std.
            line = f"Env={env_id}, ITER={it}: Mean={mean_ret:.3f} ± {2.0*std_ret:.3f} (over {args.seeds} seeds)"
            print(line)
            ftxt.write(line + "\n")
            ftxt.flush()
            fcsv.write(f"{env_id},{it},{args.seeds},{mean_ret:.6f},{std_ret:.6f}\n")
            fcsv.flush()

    print(f"Done. Wrote {args.out_txt} and {args.out_csv}.")


# Run the ITER sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
